{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import pickle\n",
    "import jsonlines\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import json\n",
    "import copy\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "prompt_template = \"The beauty item has following attributes: \\n name is <TITLE>; brand is <BRAND>; price is <PRICE>. \\n\"\n",
    "feat_template = \"The item has following features: <CATEGORIES>. \\n\"\n",
    "desc_template = \"The item has following descriptions: <DESCRIPTION>. \\n\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data = json.load(open(\"./handled/item2attributes.json\", \"r\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "all_feats = []\n",
    "\n",
    "for user, user_attris in data.items():\n",
    "    for feat_name in user_attris.keys():\n",
    "        if feat_name not in all_feats:\n",
    "            all_feats.append(feat_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_attri(item_str, attri, item_info):\n",
    "\n",
    "    if attri not in item_info.keys():\n",
    "        new_str = item_str.replace(f\"<{attri.upper()}>\", \"unknown\")\n",
    "    else:\n",
    "        new_str = item_str.replace(f\"<{attri.upper()}>\", str(item_info[attri]))\n",
    "\n",
    "    return new_str"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_feat(item_str, feat, item_info):\n",
    "\n",
    "    if feat not in item_info.keys():\n",
    "        return \"\"\n",
    "    \n",
    "    assert isinstance(item_info[feat], list)\n",
    "    feat_str = \"\"\n",
    "    for meta_feat in item_info[feat][0]:\n",
    "        feat_str = feat_str + meta_feat + \"; \"\n",
    "    new_str = item_str.replace(f\"<{feat.upper()}>\", feat_str)\n",
    "\n",
    "    if len(new_str) > 2048: # avoid exceed the input length limitation\n",
    "        return new_str[:2048]\n",
    "\n",
    "    return new_str"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "item_data = {}\n",
    "for key, value in tqdm(data.items()):\n",
    "    item_str = copy.deepcopy(prompt_template)\n",
    "    item_str = get_attri(item_str, \"title\", value)\n",
    "    item_str = get_attri(item_str, \"brand\", value)\n",
    "    item_str = get_attri(item_str, \"date\", value)\n",
    "    item_str = get_attri(item_str, \"price\", value)\n",
    "\n",
    "    feat_str = copy.deepcopy(feat_template)\n",
    "    feat_str = get_feat(feat_str, \"categories\", value)\n",
    "    desc_str = copy.deepcopy(desc_template)\n",
    "    desc_str = get_attri(desc_str, \"description\", value)\n",
    "    \n",
    "    item_data[key] = item_str + feat_str + desc_str"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "item_data[\"1304351475\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "json.dump(item_data, open(\"./handled/item_str.json\", \"w\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "item_data = json.load(open(\"./handled/item_str.json\", \"r\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import jsonlines\n",
    "\n",
    "def save_data(data_path, data):\n",
    "    '''write all_data list to a new jsonl'''\n",
    "    with jsonlines.open(\"./handled/\"+ data_path, \"w\") as w:\n",
    "        for meta_data in data:\n",
    "            w.write(meta_data)\n",
    "\n",
    "id_map = json.load(open(\"./handled/id_map.json\", \"r\"))[\"item2id\"]\n",
    "json_data = []\n",
    "for key, value in item_data.items():\n",
    "    json_data.append({\"input\": value, \"target\": \"\", \"item\": key, \"item_id\": id_map[key]})\n",
    "\n",
    "save_data(\"item_str.jsonline\", json_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import requests\n",
    "import json\n",
    "\n",
    "url = \"\"\n",
    "\n",
    "payload = json.dumps({\n",
    "   \"model\": \"text-embedding-ada-002\",\n",
    "   \"input\": \"The food was delicious and the waiter...\"\n",
    "})\n",
    "headers = {\n",
    "   'Authorization': '',\n",
    "   'User-Agent': 'Apifox/1.0.0 (https://apifox.com)',\n",
    "   'Content-Type': 'application/json'\n",
    "}\n",
    "\n",
    "response = requests.request(\"POST\", url, headers=headers, data=payload)\n",
    "\n",
    "print(response.text)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_response(prompt):\n",
    "    url = \"\"\n",
    "\n",
    "    payload = json.dumps({\n",
    "    \"model\": \"text-embedding-ada-002\",\n",
    "    \"input\": prompt\n",
    "    })\n",
    "    headers = {\n",
    "    'Authorization': '',\n",
    "    'User-Agent': 'Apifox/1.0.0 (https://apifox.com)',\n",
    "    'Content-Type': 'application/json'\n",
    "    }\n",
    "\n",
    "    response = requests.request(\"POST\", url, headers=headers, data=payload)\n",
    "    re_json = json.loads(response.text)\n",
    "\n",
    "    return re_json[\"data\"][0][\"embedding\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "item_emb = {}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "value_list = []\n",
    "\n",
    "for key, value in tqdm(item_data.items()):\n",
    "    if len(value) > 4096:\n",
    "        value_list.append(key)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if os.path.exists(\"./handled/item_emb.pkl\"):    # check whether some item emb exist in cache\n",
    "    item_emb = pickle.load(open(\"./handled/item_emb.pkl\", \"rb\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "count = 1\n",
    "while 1:    # avoid broken due to internet connection\n",
    "    if len(item_emb) == len(item_data):\n",
    "        break\n",
    "    try:\n",
    "        for key, value in tqdm(item_data.items()):\n",
    "            if key not in item_emb.keys():\n",
    "                if len(value) > 4096:\n",
    "                    value = value[:4095]\n",
    "                item_emb[key] = get_response(value)\n",
    "                count += 1\n",
    "    except:\n",
    "        pickle.dump(item_emb, open(\"./handled/item_emb.pkl\", \"wb\"))\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "len(item_emb)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "id_map = json.load(open(\"./handled/id_map.json\", \"r\"))[\"id2item\"]\n",
    "emb_list = []\n",
    "for id in range(1, len(item_emb)+1):\n",
    "    meta_emb = item_emb[id_map[str(id)]]\n",
    "    emb_list.append(meta_emb)\n",
    "\n",
    "emb_list = np.array(emb_list)\n",
    "pickle.dump(emb_list, open(\"./handled/itm_emb_np.pkl\", \"wb\"))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "llm",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.5"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
